import tiktoken
from .tokenizer import Tokenizer

class GPT2Tokenizer(Tokenizer):
    """Tokenizer backed by tiktoken's pretrained GPT-2 BPE encoding."""

    def __init__(self):
        # Underlying tiktoken encoder for the "gpt2" vocabulary.
        self.enc = tiktoken.get_encoding("gpt2")
        # Map every token id to its decoded string form.
        self.int_to_str = {}
        for token_id in range(self.enc.n_vocab):
            self.int_to_str[token_id] = self.enc.decode([token_id])
        # Inverse mapping (decoded string -> token id).
        # NOTE(review): distinct byte-level tokens can decode to the same
        # replacement string (U+FFFD), so this inversion may collapse
        # entries (last id wins) — confirm the Tokenizer base class
        # tolerates a vocab smaller than enc.n_vocab.
        self.str_to_int = {
            text: token_id for token_id, text in self.int_to_str.items()
        }
        super().__init__(self.str_to_int)

    def encode(self, text):
        """Encode *text* into GPT-2 token ids; "<|endoftext|>" is allowed as a special token."""
        special = {"<|endoftext|>"}
        return self.enc.encode(text, allowed_special=special)

    def decode(self, tokens):
        """Decode a sequence of token ids back into a string."""
        return self.enc.decode(tokens)

    def get_vocab(self):
        """Return the id -> string vocabulary mapping."""
        return self.int_to_str

    def n_vocab(self):
        """Return the size of the underlying encoding's vocabulary."""
        return self.enc.n_vocab
